#' Custom two-class summary function
#'
#' Function used to compute performance metrics when running \code{\link[caret]{train}}
#'     (pass it as \code{summaryFunction} to \code{\link[caret]{trainControl}}).
#'     The first element of \code{lev} is treated as the "positive" (event) class.
#'     The following metrics are returned as a named numeric vector:
#'     \describe{
#'         \item{Accuracy}{(TP+TN)/N}
#'         \item{AccuracyNull}{No-information rate, i.e. the proportion of the most frequent observed class}
#'         \item{AccuracyPValue}{one-sided p-value of the test that \code{Accuracy} is greater than \code{AccuracyNull}}
#'         \item{Balanced Accuracy}{(\code{Sensitivity}+\code{Specificity})/2}
#'         \item{Precision}{TP/(TP+FP) ('How many instances \emph{labeled positive} are correctly classified?')}
#'         \item{Recall}{TP/(TP+FN) ('How many \emph{truly positive} instances are labeled correctly?')}
#'         \item{Sensitivity}{TP/(TP+FN) = \code{Recall} (true-positive rate)}
#'         \item{Specificity}{TN/(TN+FP) (true-negative rate, i.e. the Recall of the negative class)}
#'         \item{Kappa}{Cohen's Kappa: agreement of predicted and observed classes, corrected for chance agreement}
#'         \item{logLoss}{negative log-likelihood of the predicted class probabilities (see \code{\link[caret]{mnLogLoss}})}
#'         \item{AUC}{Area under the Receiver Operating Characteristic (ROC) curve}
#'         \item{PR-AUC}{Area under the Precision-Recall curve (\code{NA} if it cannot be computed)}
#'         \item{F0.5}{F-measure (see Note) with beta = 0.5 (more weight on Precision than on Recall)}
#'         \item{F1}{F-measure (see Note) with beta = 1 (Precision and Recall weighted equally)}
#'         \item{F2}{F-measure (see Note) with beta = 2 (more weight on Recall than on Precision)}
#'     }
#'
#' @note The F-measure is computed as
#'     \deqn{F_\beta = (1 + \beta^2) \frac{Precision \cdot Recall}{\beta^2 \cdot Precision + Recall}}{F_beta = (1 + beta^2) * (Precision * Recall) / (beta^2 * Precision + Recall)}
#'
#' @param data a data frame with factor columns \code{obs} and \code{pred} for the observed and predicted outcomes,
#'     and columns with predicted probabilities for each outcome class (named after the class levels).
#'     See the \code{classProbs} argument to \code{\link[caret]{trainControl}}.
#'
#' @param lev a character vector of factor levels for the response.
#'     The first element is passed to \code{\link[caret]{confusionMatrix}}'s \code{positive}.
#'     If \code{NULL} (default), \code{levels(data$obs)} is used.
#'
#' @param model a character string for the model name
#'     (as taken from the \code{method} argument of \code{\link[caret]{train}}).
#'
#' @return A named numeric vector with the metrics listed above.
#'
#' @examples
#' \dontrun{
#' library(dplyr)
#' dat <- data.frame(
#'   pred = sample(1:2, 10, replace = TRUE)
#'   , obs = sample(1:2, 10, replace = TRUE)
#' ) %>%
#'   mutate(
#'     `1` = ifelse(pred == 1, sample(seq(.51, .99, length.out = 100), 10), sample(seq(0.01, .49, length.out = 100), 10))
#'     , `2` = 1 - `1`
#'   ) %>%
#'   mutate(across(c(pred, obs), ~ factor(.x, levels = 1:2)))
#' superSumFun(dat, levels(dat$obs))
#' }
superSumFun <- function(data, lev = NULL, model = NULL) {
  # `data[, col]` on a tibble returns a tibble, not a vector
  if (inherits(data, "tbl_df"))
    data <- as.data.frame(data)

  lvls <- levels(data$obs)
  if (length(lvls) != 2)
    stop(paste("Your outcome has", length(lvls), "levels. superSumFun() can only handle two outcome classes."), call. = FALSE)

  # identical() avoids silent recycling when the numbers of levels differ
  if (!identical(levels(data$pred), lvls))
    stop("levels of observed and predicted data do not match", call. = FALSE)

  # caret passes the outcome levels as `lev`; default to the observed levels
  if (is.null(lev))
    lev <- lvls

  if (!all(lev %in% colnames(data)))
    stop("`data` must contain columns recording class probabilities. Set `classProbs = TRUE` in `trainControl`.", call. = FALSE)

  for (pkg in c("ModelMetrics", "MLmetrics")) {
    if (!requireNamespace(pkg, quietly = TRUE))
      stop(paste("package", pkg, "is required"), call. = FALSE)
  }

  pos <- lev[1]  # "positive" (event) class
  neg <- lev[2]

  lloss <- caret::mnLogLoss(data = data, lev = lev, model = model)
  # ModelMetrics::auc() takes 1 for the positive class and that class's predicted probability
  rocAUC <- ModelMetrics::auc(ifelse(data$obs == neg, 0, 1), data[, pos])
  # PR-AUC cannot always be computed (e.g. degenerate predictions); report NA then
  pr_auc <- tryCatch(
    MLmetrics::PRAUC(y_pred = data[, pos], y_true = ifelse(data$obs == pos, 1, 0))
    , error = function(e) NA
  )

  CM <- caret::confusionMatrix(data$pred, data$obs, positive = pos, mode = "everything")
  # F-measure of the positive class for a given beta
  f_meas <- function(beta) {
    caret::F_meas(data = data$pred, reference = data$obs, relevant = pos, beta = beta)
  }

  c(
    CM$overall[c("Accuracy", "AccuracyNull", "AccuracyPValue")]
    , CM$byClass[c("Balanced Accuracy", "Precision", "Recall", "Sensitivity", "Specificity")]
    , CM$overall["Kappa"]
    , lloss
    , c(
      "AUC" = rocAUC
      , "PR-AUC" = pr_auc
      , "F0.5" = f_meas(0.5)
      , "F1" = f_meas(1)
      , "F2" = f_meas(2)
    )
  )
}


#' Custom multi-class summary function
#'
#' Function used to compute performance metrics when running \code{\link[caret]{train}}
#'     (pass it as \code{summaryFunction} to \code{\link[caret]{trainControl}}).
#'     The following metrics are returned as a named numeric vector:
#'     \describe{
#'         \item{Accuracy}{proportion of correctly classified instances}
#'         \item{AccuracyNull}{No-information rate, i.e. the proportion of the most frequent observed class}
#'         \item{Kappa}{Cohen's Kappa: agreement of predicted and observed classes, corrected for chance agreement}
#'         \item{N}{entries of the confusion matrix (concatenated column-wise), named \code{N_<observed>--<predicted>}}
#'         \item{logLoss}{multinomial negative log-likelihood of the predicted class probabilities (see \code{\link[caret]{mnLogLoss}})}
#'         \item{AUC}{Area under the one-vs-rest Receiver Operating Characteristic (ROC) curve}
#'         \item{PR-AUC}{Area under the one-vs-rest Precision-Recall curve}
#'         \item{Prevalence}{(class-specific) prevalence}
#'         \item{Balanced Accuracy}{(\code{Sensitivity}+\code{Specificity})/2}
#'         \item{Precision}{TP/(TP+FP) ('How many instances \emph{labeled positive} are correctly classified?')}
#'         \item{Recall}{TP/(TP+FN) ('How many \emph{truly positive} instances are labeled correctly?')}
#'         \item{Sensitivity}{TP/(TP+FN) = \code{Recall} (true-positive rate)}
#'         \item{Specificity}{TN/(TN+FP) (true-negative rate, i.e. the Recall of the negative class)}
#'         \item{F1}{F-measure (see Note) with beta = 1 (Precision and Recall weighted equally)}
#'     }
#' \code{logLoss}, \code{AUC} and \code{PR-AUC} are only reported if \code{data} contains
#'     a column of predicted probabilities for every class.
#'
#' Except for "Accuracy", "AccuracyNull", "Kappa", "N" and "logLoss",
#'    cross-class means ("Mean_*"),
#'    class prevalence-weighted means ("WMean_*"),
#'    and class-specific metrics (class label appended after "__", e.g. "Precision__a") are reported.
#'    For "Prevalence", no (weighted) mean is reported.
#'    With only two classes, the confusion-matrix based class metrics ("Prevalence" to "F1")
#'    refer to the first level (caret's positive class) and are reported without means or suffixes.
#'    Blanks are removed from all names (e.g. "Balanced Accuracy" becomes "BalancedAccuracy").
#'
#' @note The F-measure is computed as
#'     \deqn{F_\beta = (1 + \beta^2) \frac{Precision \cdot Recall}{\beta^2 \cdot Precision + Recall}}{F_beta = (1 + beta^2) * (Precision * Recall) / (beta^2 * Precision + Recall)}
#'
#' @param data a data frame with factor columns \code{obs} and \code{pred} for the observed and predicted outcomes,
#'     and (optionally) columns with predicted probabilities for each outcome class (named after the class levels).
#'     See the \code{classProbs} argument to \code{\link[caret]{trainControl}}.
#'
#' @param lev a character vector of factor levels for the response.
#'     Default is \code{NULL}, in which case \code{levels(data$pred)} is used.
#'
#' @param model an (optional) character string for the model name
#'     (as taken from the \code{method} argument of \code{\link[caret]{train}}).
#'     Default is \code{NULL}.
#'
#' @return A named numeric vector with the metrics described above.
#'
#' @examples
#' \dontrun{
#' data <- data.frame(
#'   obs = c("a", "b", "b", "c")
#'   , pred = c("a", "a", "b", "c")
#'   , a = c(.8, .6, .1, .1)
#'   , b = c(.1, .3, .8, .1)
#'   , c = c(.1, .1, .1, .8)
#' )
#'
#' lev <- letters[1:3]
#' data$obs <- factor(data$obs, lev, lev)
#' data$pred <- factor(data$pred, lev, lev)
#'
#' multiClassSum(data, lev)
#' }
multiClassSum <- function(data, lev = NULL, model = NULL) {
  # `data[, col]` on a tibble returns a tibble, not a vector
  if (inherits(data, "tbl_df"))
    data <- as.data.frame(data)
  stopifnot(
    "`data` needs to have columns 'obs' (factor) and 'pred' (factor)." = all(c("pred", "obs") %in% colnames(data))
    , "`data$obs` needs to be a factor." = is.factor(data$obs)
    , "`data$pred` needs to be a factor." = is.factor(data$pred)
  )
  lvls <- levels(data[, "pred"])
  # identical() avoids silent recycling when the numbers of levels differ
  if (!identical(lvls, levels(data[, "obs"])))
    stop("levels of observed and predicted data do not match", call. = FALSE)
  n_lvls <- length(lvls)

  # caret passes the outcome levels as `lev`; default to the observed levels
  if (is.null(lev))
    lev <- lvls

  # class weights for the weighted means: observed class proportions
  w <- as.vector(prop.table(table(data[, "obs"]))[lvls])

  # probability-based metrics need one probability column per class
  has_class_probs <- all(c(lev, lvls) %in% colnames(data))
  if (has_class_probs) {
    lloss <- caret::mnLogLoss(data = data, lev = lev, model = model)
    for (pkg in c("pROC", "MLmetrics")) {
      if (!requireNamespace(pkg, quietly = TRUE))
        stop(paste("package", pkg, "is required"), call. = FALSE)
    }
    # one-vs-rest curves: one row per class, columns "AUC" and "PR-AUC";
    # NA where a curve cannot be computed (e.g. class absent from `obs`)
    prob_stats <- t(vapply(lvls, function(x) {
      is_x <- ifelse(data[, "obs"] == x, 1, 0)
      prob_x <- data[, x]
      roc_auc <- tryCatch(
        as.numeric(pROC::roc(is_x, prob_x, direction = "<", quiet = TRUE)$auc)
        , error = function(e) NA_real_
      )
      pr_auc <- tryCatch(
        as.numeric(MLmetrics::PRAUC(y_pred = prob_x, y_true = is_x))
        , error = function(e) NA_real_
      )
      c("AUC" = roc_auc, "PR-AUC" = pr_auc)
    }, FUN.VALUE = numeric(2)))

    prob_stats_means <- colMeans(prob_stats, na.rm = TRUE)
    names(prob_stats_means) <- paste0("Mean_", names(prob_stats_means))
    prob_stats_wmeans <- apply(prob_stats, 2, weighted.mean, w = w, na.rm = TRUE)
    names(prob_stats_wmeans) <- paste0("WMean_", names(prob_stats_wmeans))
  }

  CM <- caret::confusionMatrix(data[, "pred"], data[, "obs"], mode = "everything")
  overall_stats <- CM$overall[c("Accuracy", "AccuracyNull", "Kappa")]
  # CM$table is predicted (rows) x observed (columns) and as.vector() reads it
  # column-wise, so the entries are named N_<observed>--<predicted>
  counts <- as.vector(CM$table)
  names(counts) <- paste0("N_", rep(lvls, each = n_lvls), "--", rep(lvls, times = n_lvls))
  overall_stats <- c(overall_stats, counts)

  if (has_class_probs) {
    overall_stats <- c(
      overall_stats
      , logLoss = as.numeric(lloss)
      , prob_stats_means[1]
      , prob_stats_wmeans[1]
      , setNames(prob_stats[, 1], paste0(colnames(prob_stats)[1], "__", lvls))
      , prob_stats_means[2]
      , prob_stats_wmeans[2]
      , setNames(prob_stats[, 2], paste0(colnames(prob_stats)[2], "__", lvls))
    )
  }

  class_metrics <- c("Prevalence", "Balanced Accuracy", "Precision", "Recall", "Sensitivity", "Specificity", "F1")
  if (n_lvls == 2) {
    # with two classes caret returns `byClass` as a named vector (positive class only)
    class_stats <- CM$byClass[class_metrics]
  } else {
    by_class <- CM$byClass[, class_metrics, drop = FALSE]
    # rows: mean, weighted mean, then one row per class; read column-wise so that
    # each metric contributes Mean_, WMean_ and its class-specific values in turn
    stats_mat <- rbind(
      colMeans(by_class, na.rm = TRUE)
      , apply(by_class, 2, weighted.mean, w = w, na.rm = TRUE)
      , by_class
    )
    class_stats <- as.vector(stats_mat)
    names(class_stats) <- paste0(
      rep(c("Mean_", "WMean_", rep("", n_lvls)), times = length(class_metrics))
      , rep(class_metrics, each = n_lvls + 2L)
      , rep(c("", "", paste0("__", lvls)), times = length(class_metrics))
    )
    # (weighted) mean prevalence is uninformative: drop Mean_Prevalence and WMean_Prevalence
    class_stats <- class_stats[-(1:2)]
  }

  stats <- c(overall_stats, class_stats)
  names(stats) <- gsub("\\s+", "", names(stats))
  stats
}
# multiClassSum(data, lev)


#' Parse the training results from multi-class training
#'
#' Reshapes the resampling results of a \code{caret} 'train' object into long format.
#'     Metric names are expected to follow the naming scheme of \code{multiClassSum}
#'     (e.g. "Mean_F1", "WMean_F1", "Precision__<class label>").
#'     Columns whose names end in "SD" (resampling standard deviations) are dropped.
#'
#' @param x a \code{caret} 'train' object or its results data frame (element 'results')
#' @param param.col.idxs integer indexes of parameter columns
#'     (usually \code{1:k}, where \code{k} is the number of tuning parameters)
#' @param labs character vector of outcome class labels used during training
#' @param metrics a character vector specifying the metrics to be retained.
#'     Defaults to cross-class F1 and Balanced Accuracy, and class-specific prevalence,
#'     precision, sensitivity and specificity.
#'
#' @return A long-format tibble with the parameter columns and the columns
#'     \code{what} ("class-wise", "cross-class mean" or "cross-class weighted mean"),
#'     \code{metric} (metric name), \code{class} (class label, or "all" for cross-class summaries)
#'     and \code{value}.
parse_multiclass_train_results <- function(
  x
  , param.col.idxs
  , labs
  , metrics = c(
    "Mean_F1", "Mean_BalancedAccuracy"
    , paste0("Prevalence__", labs)
    , paste0("Precision__", labs)
    , paste0("Sensitivity__", labs)
    , paste0("Specificity__", labs)
  )
) {
  if (inherits(x, "train")) x <- x$results
  x %>%
    dplyr::select(-dplyr::ends_with("SD")) %>%
    tidyr::pivot_longer(-dplyr::all_of(param.col.idxs)) %>%
    dplyr::filter(name %in% metrics) %>%
    # split e.g. "WMean_F1" or "Precision__lab" into prefix, metric and class suffix;
    # the suffix uses `.+` (not `\\w+`) so labels containing "." (valid R names) still match
    tidyr::extract(name, c("what", "metric", "class"), regex = "^(W?Mean_|)([^_]+)(__.+|)$") %>%
    dplyr::mutate(
      class = ifelse(class == "", "all", sub("^__", "", class))
      , what = dplyr::case_when(
        what == "" ~ "class-wise"
        , what == "Mean_" ~ "cross-class mean"
        , what == "WMean_" ~ "cross-class weighted mean"
        , TRUE ~ NA_character_
      )
    )
}


